# Third-Party Imports
import nltk
import openai
from evaluate import load
from datasets import load_dataset
import requests as req
import numpy as np
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.text_rank import TextRankSummarizer
from sumy.summarizers.lsa import LsaSummarizer
from sumy.summarizers.luhn import LuhnSummarizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo
pyo.init_notebook_mode()
from dotenv import load_dotenv
load_dotenv()
# Standard Imports
import os
import json
from string import punctuation
# Plotting functions
def plot_bar_data(*bars, x=None, title="", x_label="", y_label=""):
    """
    Build a grouped plotly bar chart from (name, y_values) pairs.

    Parameters
    ----------
    *bars : (name, y_values) tuples, e.g. produced by `create_bar`.
    x : shared x-axis categories used by every bar series.
    title, x_label, y_label : figure annotations.

    Returns the plotly `Figure` (caller is responsible for showing it).
    """
    layout = {
        "title": title,
        "xaxis": {"title": x_label},
        "yaxis": {"title": y_label},
        "barmode": "group",
    }
    traces = [go.Bar(name=f"{name}", x=x, y=values) for name, values in bars]
    return go.Figure(layout=layout, data=traces)
def create_bar(name, data):
    """Pair a series label with its y-values for `plot_bar_data`."""
    return name, data
# Tokenization
def tokenize(doc, remove_stopwords=True):
    """
    Return the lower-cased word tokens of `doc` with punctuation
    characters (and, optionally, English stopwords) removed.
    """
    excluded = set(punctuation)
    if remove_stopwords:
        excluded.update(nltk.corpus.stopwords.words("english"))
    tokens = []
    for word in nltk.word_tokenize(doc):
        lowered = word.lower()
        if lowered not in excluded:
            tokens.append(lowered)
    return tokens
# Document Summariser Class
# Implementation of all NLP methods used by LAME for text summarisation in a single class
class DocSummariser():
    """
    Implementation of all NLP methods used by LAME for text
    summarisation in a single class.

    A corpus (dict of file name -> document text) is loaded with
    `load_files`; `summarise` then summarises any subset of those
    documents with one of the supported methods:

      - "se":     simple extractive (word-frequency) summarisation
      - "lexR":   LexRank (sumy)
      - "texR":   TextRank (sumy)
      - "lsa":    Latent Semantic Analysis (sumy)
      - "luhn":   Luhn's algorithm (sumy)
      - "bart":   facebook/bart-large-cnn via the HuggingFace Inference API
      - "openai": text-davinci-003 via the OpenAI completion API
    """

    def __init__(self):
        # Mapping of file name -> document text
        self._corpus = dict()

    def get_corpus(self):
        """Return the currently loaded corpus (name -> text dict)."""
        return self._corpus

    def load_files(self, corpus):
        """Replace the current corpus with `corpus` (name -> text dict)."""
        self._corpus = corpus

    def clear_files(self):
        """Drop every loaded document."""
        self._corpus = dict()

    def _word_tokenize(self, text):
        # Word tokens of `text`, minus punctuation and English stopwords.
        # NOTE: matching is case-sensitive, so capitalised stopwords
        # (e.g. "The") survive — mirrors the original behaviour.
        banned = set(punctuation) | set(nltk.corpus.stopwords.words("english"))
        return [
            w for w in nltk.word_tokenize(text)
            if w not in banned
        ]

    def _chunk_text(self, text, chunk_len):
        """
        Split `text` into chunks of whole sentences, each at most roughly
        `chunk_len` word tokens long (a chunk may exceed the limit only
        if a single sentence does).

        Fixes two bugs in the original implementation: the sentence that
        pushed a chunk over the limit was silently dropped, and an empty
        trailing chunk could be appended for empty input.
        """
        chunks = []
        current_chunk = ""
        for sent in nltk.sent_tokenize(text):
            if len(nltk.word_tokenize(current_chunk + f" {sent}")) >= chunk_len:
                # Close the full chunk, then start the next one with this
                # sentence instead of discarding it.
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = f" {sent}"
            else:
                current_chunk += f" {sent}"
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    def summarise(self, method, fnames, summary_size):
        """
        Summarise the concatenation of the documents named in `fnames`.

        Parameters
        ----------
        method : one of "se", "lexR", "texR", "lsa", "luhn", "bart", "openai".
        fnames : names of loaded corpus documents to summarise together.
        summary_size : target summary length as a fraction of the input.

        Raises
        ------
        KeyError if a name in `fnames` is not in the corpus.
        ValueError for an unknown `method` (previously returned None).
        """
        # Build input text
        text = " ".join(self._corpus[name] for name in fnames)
        # Choose method and return summary
        if method == "se":
            return self._SE_summary(text, summary_size).strip()
        if method in ("lexR", "texR", "lsa", "luhn"):
            return self._algo_summary(text, method, summary_size).strip()
        if method == "bart":
            # The hosted model has a limited input window, so long texts
            # are summarised chunk by chunk and rejoined.
            return " ".join(
                self._BART_summary(chunk, summary_size)
                for chunk in self._chunk_text(text, 400)
            ).strip()
        if method == "openai":
            return " ".join(
                self._openai_summary(chunk, summary_size)
                for chunk in self._chunk_text(text, 500)
            ).strip()
        raise ValueError(f"Unknown summarisation method: {method!r}")

    def _SE_summary(self, text, summary_size=0.5):
        """
        Simple extractive summary: keep the sentences whose summed word
        frequencies beat a threshold derived from the mean sentence score.
        """
        # Word-frequency table over the whole document, built in one pass
        # (the original used list.count per unique word: O(n*k)).
        w_freq_table = {}
        for w in self._word_tokenize(text):
            w_freq_table[w] = w_freq_table.get(w, 0) + 1
        sents = nltk.sent_tokenize(text)
        if not sents:
            return ""  # guard: original raised ZeroDivisionError here
        # Scores kept in a list so duplicate sentences no longer collapse
        # into one dict key.
        sent_scores = [
            sum(w_freq_table.get(w, 0) for w in self._word_tokenize(sent))
            for sent in sents
        ]
        # The multiplier shrinks as summary_size grows, so larger
        # summary_size keeps more sentences.
        multiplier = 2 * (1 - summary_size)
        avg = sum(sent_scores) / len(sent_scores)
        return " ".join(
            sent for sent, score in zip(sents, sent_scores)
            if score >= avg * multiplier
        )

    def _algo_summary(self, text, method, summary_size=0.5):
        """Summarise `text` with one of the sumy algorithmic summarisers."""
        # Target number of summary sentences (at least one)
        sent_length = len(nltk.sent_tokenize(text))
        summary_len = max(int(summary_size * sent_length), 1)
        # Dispatch table instead of an if/elif chain; an unknown method
        # previously fell through and raised NameError below.
        summarisers = {
            "lexR": LexRankSummarizer,
            "texR": TextRankSummarizer,
            "lsa": LsaSummarizer,
            "luhn": LuhnSummarizer,
        }
        try:
            summariser = summarisers[method]()
        except KeyError:
            raise ValueError(f"Unknown sumy method: {method!r}") from None
        # Parse plain text and build the summary
        parser = PlaintextParser(text, Tokenizer("english"))
        summary_sents = summariser(parser.document, summary_len)
        return " ".join(str(s) for s in summary_sents)

    def _BART_summary(self, text, summary_size=0.5):
        """
        Summarise `text` with facebook/bart-large-cnn through the
        HuggingFace Inference API (requires HUGGING_FACE_API_KEY in env).
        On an API error the error message (plus a newline) is returned
        instead of a summary.
        """
        # Requested summary length in words
        word_len = len(nltk.word_tokenize(text))
        summary_len = int((summary_size * word_len) + 0.5)
        api_url = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
        headers = {
            "Authorization": f"Bearer {os.getenv('HUGGING_FACE_API_KEY')}"
        }
        payload = {
            "inputs": text,
            "parameters": {
                "do_sample": False,
                # max_length rounded up to the nearest 100 but capped at
                # the input length; min_length kept positive
                "max_length": min(round(summary_len + 50, -2), word_len),
                "min_length": max(summary_len - 10, 1),
            }
        }
        res = req.post(api_url, headers=headers, data=json.dumps(payload))
        content = json.loads(res.content.decode("utf-8"))
        if isinstance(content, dict):
            # Error responses come back as a dict with an "error" key
            return content.get("error", "Something's wrong") + "\n"
        if isinstance(content, list):
            return content[0].get("summary_text")
        return ""  # unexpected payload shape; original returned None here

    def _openai_summary(self, text, summary_size=0.5):
        """
        Summarise `text` with the OpenAI completion API
        (requires OPENAI_API_KEY in env).

        NOTE(review): uses the legacy `openai.Completion` interface and the
        retired text-davinci-003 model — only works with openai-python < 1.0.
        """
        word_len = len(nltk.word_tokenize(text))
        summary_len = int((summary_size * word_len) + 0.5)
        openai.api_key = os.getenv("OPENAI_API_KEY")
        prompt = f"Summarize the following text in no more than {summary_len} words:\n\n{text}\n\nSummary:"
        # Token budget rounded to the nearest 100, with a floor of 50
        max_tokens = round(summary_len + 50, -2)
        if max_tokens < 1:
            max_tokens = 50
        res = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            temperature=0,
            max_tokens=max_tokens,
            logprobs=0,
        )
        return res.choices[0].text
def load_article_data(subset_size=5):
    """
    Load `subset_size` randomly chosen articles from the validation
    split of the CNN/DailyMail dataset (13,368 examples).
    """
    picks = np.random.randint(0, 13368, (subset_size,))
    validation_split = load_dataset(
        "cnn_dailymail",
        "3.0.0",
        split="validation",
    )
    return validation_split.select(picks)
def summarise_sample(sample, method):
    """
    Run a text summarisation method on a single example from the
    CNN/DailyMail dataset.

    Returns a (predicted_summary, reference_highlight) tuple.
    """
    summariser = DocSummariser()
    # Load the article as a one-document corpus, summarise it at 10%
    # of the original length, then reset the summariser.
    summariser.load_files({"Doc": sample["article"]})
    prediction = summariser.summarise(method, ["Doc"], 0.1)
    summariser.clear_files()
    return prediction, sample["highlights"]
def summarise_samples(art_ds, method):
    """
    Run a text summarisation method on multiple examples from the
    CNN/DailyMail dataset.

    Returns a (predictions, references) pair of lists, index-aligned
    with the dataset.
    """
    pairs = [summarise_sample(sample, method) for sample in art_ds]
    if not pairs:
        return [], []
    predictions, references = map(list, zip(*pairs))
    return predictions, references
def evaluate_method(art_ds, method, rouge_metric):
    """
    Get the average ROUGE scores of a text summarisation
    method after running it on a subset of the CNN/DailyMail
    dataset.
    """
    predictions, references = summarise_samples(art_ds, method)
    return rouge_metric.compute(predictions=predictions, references=references)
def visualise_results(results):
    """
    Take results from the `evaluate_method` function and create bar
    graphs to visualise them.

    Returns a dict containing one plot of the average scores across
    all methods ("average_score_plot") plus one per-trial plot for
    each method ("<method>_plot").
    """
    method_labels = {
        "se": "Simple Extractive Summarisation",
        "lexR": "LexRank Algorithm",
        "texR": "TextRank Algorithm",
        "lsa": "Latent Semantic Analysis",
        "luhn": "Luhn's Algorithm",
        "bart": "BART",
        "openai": "OpenAI",
    }
    variants = ("1", "2", "L", "Lsum")
    plots = dict()
    # Plot of average scores, grouped by method
    method_names = [method_labels[r["method"]] for r in results]
    avg_bars = [
        create_bar(
            f"Average ROUGE-{v} Score",
            [r[f"avg_rouge{v}"] for r in results],
        )
        for v in variants
    ]
    plots["average_score_plot"] = plot_bar_data(
        *avg_bars,
        x=method_names,
        title="Average Scores",
    )
    # One plot per method showing its ROUGE scores over every trial
    for r in results:
        sample_names = [f"Sample #{i+1}" for i in range(len(r["rouge1_scores"]))]
        trial_bars = [
            create_bar(f"ROUGE-{v} Score", r[f"rouge{v}_scores"])
            for v in variants
        ]
        plots[f"{r['method']}_plot"] = plot_bar_data(
            *trial_bars,
            x=sample_names,
            title=f"ROUGE Scores for {method_labels[r['method']]}",
        )
    return plots
def method_evaluator(methods, num_trials=10, dataset_size=50):
    """
    Evaluate several text summarisation methods at once.

    For each of `num_trials` trials, a fresh random subset of
    `dataset_size` CNN/DailyMail validation articles is sampled and
    every method in `methods` is scored on it with ROUGE.

    Returns a list (one entry per method) of dicts holding the
    per-trial "rouge{1,2,L,Lsum}_scores" lists plus their
    "avg_rouge*" means.
    """
    score_keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    # Initialise results object
    results = [
        {
            "rouge1_scores": [],
            "rouge2_scores": [],
            "rougeL_scores": [],
            "rougeLsum_scores": [],
            "method": m,
        }
        for m in methods
    ]
    # Load the ROUGE evaluator once, shared by all trials
    rouge_metric = load("rouge")
    for t in range(num_trials):
        print(f"Trial #{t+1}")
        arts_ds = load_article_data(dataset_size)
        for i, m in enumerate(methods):
            result = evaluate_method(arts_ds, m, rouge_metric)
            for key in score_keys:
                score = result.get(key)
                # Skip missing scores instead of appending None, which
                # would make np.mean below raise a TypeError.
                if score is not None:
                    results[i][f"{key}_scores"].append(score)
    for entry in results:
        for key in score_keys:
            scores = entry[f"{key}_scores"]
            # Guard against an empty score list (all scores missing)
            entry[f"avg_{key}"] = np.mean(scores) if scores else float("nan")
    return results
# Get results of evaluation of each text summarisation method
# (10 trials of 10 articles each; the "bart" and "openai" methods call
# external APIs, so HUGGING_FACE_API_KEY and OPENAI_API_KEY must be set)
results = method_evaluator(["se","lexR","texR", "lsa", "luhn", "bart", "openai"], 10, 10)
# Bare expression: displays the raw results list in the notebook output
results
Trial #1
Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #2
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #3
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #4
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #5
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #6
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #7
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #8
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #9
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #10
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
[{'rouge1_scores': [0.19005207045312772,
0.3009871489354943,
0.2931411386426386,
0.19913127061547475,
0.34747150522244186,
0.21780695345713957,
0.19677072590816752,
0.20263345313430203,
0.2690306058363695,
0.28550823766596356],
'rouge2_scores': [0.05101505975469495,
0.1510140777924998,
0.11036392432707451,
0.08390162468206974,
0.13362014789671092,
0.09364049594907897,
0.07304490799056017,
0.06667913012653545,
0.1029041831157414,
0.10269853575799712],
'rougeL_scores': [0.10877960185826435,
0.20718498367278465,
0.17253018197361264,
0.13674771996945978,
0.21448509145713684,
0.16357187299206488,
0.11672870602951792,
0.1308710281363309,
0.16974310316394442,
0.17389115995035725],
'rougeLsum_scores': [0.15683948769685885,
0.25942774041970074,
0.24925089405353135,
0.1751798031402717,
0.2950106126222376,
0.19089009701889204,
0.15711285696841049,
0.17632973425353654,
0.228663830470195,
0.23856262102391734],
'method': 'se',
'avg_rouge1': 0.25025331098711195,
'avg_rouge2': 0.0968882087392963,
'avg_rougeL': 0.15945334492034735,
'avg_rougeLsum': 0.21272676776675511},
{'rouge1_scores': [0.3405239212045347,
0.35021329991096484,
0.34403613624836726,
0.31355456825794314,
0.2977057278641805,
0.31436499300731013,
0.3193784018022092,
0.3143211915200329,
0.30533178593278326,
0.3011859771768988],
'rouge2_scores': [0.0836205519395079,
0.14895156602979334,
0.13141695460786812,
0.11981384227043813,
0.06250807744248804,
0.09979772563901046,
0.11955255577092,
0.10844337387345118,
0.11594521952763542,
0.07749686171462468],
'rougeL_scores': [0.19011457862040881,
0.24242928810362477,
0.2009969199180459,
0.20742533408763114,
0.1499218846175273,
0.18951869638173008,
0.20976237857408786,
0.1986029612037844,
0.18732339073451992,
0.18217709975828067],
'rougeLsum_scores': [0.2607359158079007,
0.2988947193997544,
0.29232716337139675,
0.27938597600333825,
0.23157858388273223,
0.2711325074816871,
0.26774145197514143,
0.264836780716986,
0.25456628476817234,
0.23720013161575046],
'method': 'lexR',
'avg_rouge1': 0.3200616002925224,
'avg_rouge2': 0.10675467288157374,
'avg_rougeL': 0.19582725319996408,
'avg_rougeLsum': 0.2658399515022859},
{'rouge1_scores': [0.3121568041466592,
0.35862514994910183,
0.27465413628777746,
0.2845021964462412,
0.32764938437467706,
0.3495056760365899,
0.3026091317799985,
0.2637257129991947,
0.2848635717579817,
0.29249166946396443],
'rouge2_scores': [0.09128041649272774,
0.14248277647491403,
0.082930853032753,
0.09481728673520975,
0.12270370160995808,
0.11159554755741069,
0.10649193833458886,
0.0723151741108277,
0.09590014701357881,
0.0956953065933916],
'rougeL_scores': [0.16341548824541974,
0.2297436618616151,
0.15952422740822397,
0.18312088111458474,
0.2004944452882812,
0.21661958027941633,
0.2048016085995933,
0.17789767630935263,
0.18141941234970754,
0.16610399104565904],
'rougeLsum_scores': [0.2501236310270892,
0.3130678084938703,
0.22178242013640337,
0.24618963769205304,
0.27493036985149,
0.29154518658408,
0.26018570976406363,
0.2248939444361857,
0.23125894392449653,
0.23852494385415224],
'method': 'texR',
'avg_rouge1': 0.3050783433242186,
'avg_rouge2': 0.10162131479553602,
'avg_rougeL': 0.18831409725018536,
'avg_rougeLsum': 0.2552502595763884},
{'rouge1_scores': [0.28268831978769515,
0.24593303296837948,
0.3341495560893797,
0.1898583372324946,
0.3446262107849064,
0.23625997994066864,
0.25149293399100614,
0.3266258754178136,
0.23816089311922645,
0.27830661975692406],
'rouge2_scores': [0.06269053501928129,
0.05864786316923982,
0.11822365955548411,
0.026593796775168026,
0.13275093979686703,
0.0637347342365908,
0.06869927299005817,
0.1018317173374722,
0.05548650223344339,
0.05092124909665358],
'rougeL_scores': [0.16762835530919334,
0.1395663162266797,
0.21592338354645318,
0.105276830989179,
0.23359492174148477,
0.16122704677438973,
0.16618220906170983,
0.1852614639381589,
0.13573311394383797,
0.1500375707586039],
'rougeLsum_scores': [0.23745822143504625,
0.22313145770059323,
0.2919670979126927,
0.14868460527513774,
0.29444390840136064,
0.20345386412577832,
0.2230573509587108,
0.25612724842423285,
0.18972113883384373,
0.21805296461483073],
'method': 'lsa',
'avg_rouge1': 0.2728101759088494,
'avg_rouge2': 0.07395802702102586,
'avg_rougeL': 0.16604312122896903,
'avg_rougeLsum': 0.2286097857682227},
{'rouge1_scores': [0.33675015930904845,
0.3947931081288696,
0.37951040541191833,
0.27275276073134663,
0.33909427771756034,
0.3953526836411452,
0.3030567054715033,
0.27362116401176073,
0.33338511375415203,
0.3260083567680311],
'rouge2_scores': [0.09595906160801722,
0.17369944869324275,
0.16098788407331424,
0.09541625464520659,
0.11878552529135661,
0.18906358964170636,
0.11057661453801726,
0.10710722571984196,
0.1483535184143006,
0.12516828655085013],
'rougeL_scores': [0.19397054481679982,
0.241370074327719,
0.228814283466252,
0.18354602706953016,
0.1965069207800138,
0.269819568647694,
0.19574501753576534,
0.1935981673390528,
0.21020199012138555,
0.21379943725913003],
'rougeLsum_scores': [0.27334584192007294,
0.33454708156775836,
0.326802571866326,
0.24408001082963504,
0.2788017755657494,
0.35623044391952974,
0.2616350466126003,
0.24587814175278466,
0.2743842171730406,
0.28291636802255127],
'method': 'luhn',
'avg_rouge1': 0.33543247349453353,
'avg_rouge2': 0.13251174091758539,
'avg_rougeL': 0.21273720313633424,
'avg_rougeLsum': 0.28786214992300485},
{'rouge1_scores': [0.3939009459128039,
0.37254027357525565,
0.38511676285124147,
0.36371616182705696,
0.36525003486869606,
0.45383367199170754,
0.3912525380885452,
0.37559130516021416,
0.40166591761632986,
0.3862072712544443],
'rouge2_scores': [0.16402405590056096,
0.14793813141105328,
0.16945955011473235,
0.15095720000323404,
0.16339718798791314,
0.2376879923468648,
0.16588877028170657,
0.14972176330805426,
0.2111909198935441,
0.17020477017073637],
'rougeL_scores': [0.246439721981478,
0.23174367936284732,
0.2330795564553087,
0.24307607507124507,
0.25179734564383005,
0.3412241946299096,
0.2656396243094433,
0.21865501142765595,
0.2891048661926545,
0.2398244729324875],
'rougeLsum_scores': [0.3291652769699983,
0.31669453550047727,
0.3446023349546971,
0.30893306742931925,
0.31184284197167844,
0.4078493571781751,
0.326199718031569,
0.32155151155958994,
0.33957989656527005,
0.31689450282461157],
'method': 'bart',
'avg_rouge1': 0.38890748831462957,
'avg_rouge2': 0.17304703414184,
'avg_rougeL': 0.256058454800686,
'avg_rougeLsum': 0.33233130429853863},
{'rouge1_scores': [0.3539093887692436,
0.3629555492204734,
0.35994663094912205,
0.28893216632908786,
0.3653509550428032,
0.41770094821832787,
0.34936475914516085,
0.39112218207410765,
0.373265570806523,
0.34921358756159326],
'rouge2_scores': [0.10952401949543962,
0.1270272335224194,
0.12171058786165731,
0.08539170871758173,
0.11252715858134962,
0.18810111587872688,
0.1374317215150867,
0.14166691145414134,
0.1273621005602611,
0.13563433391025168],
'rougeL_scores': [0.21091417544004337,
0.2280331154950862,
0.20531569172470648,
0.1989245946442606,
0.21211201272113994,
0.29120842552422943,
0.24589005911983602,
0.21657056642271738,
0.24164466401099657,
0.21742281064880298],
'rougeLsum_scores': [0.29818294928079536,
0.303779430942121,
0.2886420001102047,
0.23832269856913707,
0.29695736595255895,
0.3658158836108653,
0.3009035166792679,
0.3149817685636841,
0.30353845980949024,
0.28070721527106446],
'method': 'openai',
'avg_rouge1': 0.36117617381164424,
'avg_rouge2': 0.12863768914969154,
'avg_rougeL': 0.22680361157518192,
'avg_rougeLsum': 0.2991831288789189}]
# Get data visualisations of results
results_plots = visualise_results(results)
# Each bare expression below renders one plotly figure in the notebook
# Results for simple extractive summarisation
results_plots["se_plot"]
# Results for LexRank
results_plots["lexR_plot"]
# Results for TextRank
results_plots["texR_plot"]
# Results for latent semantic analysis
results_plots["lsa_plot"]
# Results for Luhn's algorithm
results_plots["luhn_plot"]
# Results for BART
results_plots["bart_plot"]
# Results for OpenAI
results_plots["openai_plot"]
# Average ROUGE scores for all methods
results_plots["average_score_plot"]